import json
import os
from datetime import datetime

import pandas as pd
import requests
from transformers import BertForSequenceClassification, BertTokenizer, TrainingArguments, Trainer
import torch
import tensorflow as tf


# ----- user settings -----

train_model = False  # set to True to finetune classification model from scratch
# manifesto api key: read from the MANIFESTO_API_KEY environment variable so the secret does
# not have to be committed with this file; falls back to "" (replace the default if preferred)
api_key = os.environ.get("MANIFESTO_API_KEY", "")

# ----- define important meta variables -----

# Manifesto category codes (cmp_code values) grouped by the aggregated topic id they are
# recoded to. Flattened below into a code -> topic-id lookup table.
_CMP_CODES_BY_TOPIC = {
    "1": ("101", "102", "103.1", "103.2", "106", "107", "109"),
    "2": ("108", "110"),
    "3": ("104", "105"),
    "4": ("201.1", "201.2"),
    "5": ("202.1", "202.2", "202.3", "202.4", "203", "204"),
    "6": ("301", "302", "303", "304"),
    "7": ("305.1", "305.2", "305.3", "305.4", "305.5", "305.6"),
    "8": ("401", "402", "403", "404", "405", "406", "407", "408",
          "409", "410", "412", "413", "414", "415", "416.1", "704"),
    "9": ("411",),
    "10": ("416.2", "501"),
    "11": ("502",),
    "12": ("503", "705"),
    "13": ("504", "505", "706"),
    "14": ("506", "507"),
    "15": ("601.1", "602.1", "603", "604", "606.1", "606.2"),
    "16": ("601.2", "602.2", "607.1", "607.2", "607.3", "608.1", "608.2", "608.3"),
    "17": ("605.1", "605.2"),
    "18": ("701", "702"),
    "19": ("703.1", "703.2"),
    "20": ("000",),
}
recoding_dict = {code: topic
                 for topic, codes in _CMP_CODES_BY_TOPIC.items()
                 for code in codes}

# human-readable label for each aggregated topic id "1" ... "20" (listed in id order)
topic_label_dict = {
    str(topic_id): label
    for topic_id, label in enumerate(
        ("Foreign_Affairs", "European_Union", "Defense", "Freedom", "Democracy",
         "Political_System", "Political_Authority", "Economy",
         "Technology_and_Infrastructure", "Environment", "Culture", "Equality",
         "Welfare_State", "Education", "Society_and_Values", "Immigration",
         "Law_and_Order", "Labour", "Agriculture", "NA"),
        start=1,
    )
}

# create category-index mapping dicts: class index i <-> topic id str(i + 1)
cmp_codes = list(map(str, range(1, 21)))
index_to_cmp = dict(enumerate(cmp_codes))
cmp_to_index = {code: position for position, code in index_to_cmp.items()}

# ----- create custom torch dataset classes for labeled and unlabeled data -----

class MyDataset(torch.utils.data.Dataset):
    """Map-style torch dataset pairing tokenized features with integer labels.

    ``sequences`` maps each feature name (e.g. ``input_ids``) to a per-example
    sequence; ``labels`` holds one label per example. Each item is a dict of
    tensors with the extra key ``'labels'``.
    """

    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __getitem__(self, idx):
        features = self.sequences
        sample = {}
        for name in features:
            sample[name] = torch.tensor(features[name][idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        # the number of examples is defined by the label list
        return len(self.labels)


class MyDatasetRaw(torch.utils.data.Dataset):
    """Map-style torch dataset over tokenized features without labels (for prediction).

    ``sequences`` maps each feature name to a per-example sequence and must
    contain ``'input_ids'``, whose length defines the dataset size.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __getitem__(self, idx):
        sample = {}
        for name, column in self.sequences.items():
            sample[name] = torch.tensor(column[idx])
        return sample

    def __len__(self):
        input_ids = self.sequences['input_ids']
        return len(input_ids)

# ----- load and prepare press release dataset -----

# load press release dataset
data_path = "Corpus_Pressreleases_20210401_20210926.csv"
pressrelease_df = pd.read_csv(data_path, encoding="cp1252", index_col=0)

# convert date column to datetime
pressrelease_df["date"] = pd.to_datetime(pressrelease_df["date"])

# add campaign column to dataset: releases dated on/after the cut-off are "campaign"
election_time = datetime.strptime("01.07.2021", '%d.%m.%Y')
pressrelease_df["period"] = (pressrelease_df["date"] >= election_time).map(
    {True: "campaign", False: "non-campaign"})

# add type column to dataset: texts containing "termin" (case-insensitive) are treated as
# not policy-related. Missing texts are read as "" so they do not crash the lower-casing.
mentions_termin = (pressrelease_df["text"].fillna("").astype(str)
                   .str.lower().str.contains("termin", regex=False))
pressrelease_df["type"] = mentions_termin.map(
    {True: "not policy-related", False: "policy-related"})

# move period column
pressrelease_df.insert(6, 'period', pressrelease_df.pop('period'))

# split datasets in policy- and not policy-related versions; .copy() makes them independent
# frames so the column assignments that follow do not act on a view of pressrelease_df
pressrelease_non_policy_df = pressrelease_df[pressrelease_df["type"] == "not policy-related"].copy()
pressrelease_policy_df = pressrelease_df[pressrelease_df["type"] == "policy-related"].copy()

# add issue variable to dataset
pressrelease_non_policy_df["issue"] = "NA"

# ----- load manifesto corpus data -----

# set api meta variables
root_url = "https://manifesto-project.wzb.eu/tools/"
dataset_version = "MPDS2022a"

# get manifesto main dataset via api; the query goes through `params` so requests URL-encodes
# it, and HTTP errors are raised here instead of surfacing later as confusing JSON errors
core_dataset_link = f"{root_url}api_get_core.json"
response = requests.get(core_dataset_link, params={"key": dataset_version, "api_key": api_key})
response.raise_for_status()

# convert json data to pandas dataframe (the first row of the payload holds the column names)
core_dataset = json.loads(response.content.decode('utf-8'))
core_dataset = pd.DataFrame(core_dataset)
core_dataset.columns = core_dataset.iloc[0]
core_dataset.drop(index=0, inplace=True)

# create unique manifesto_id column ("<party>_<date>")
core_dataset["manifesto_id"] = core_dataset["party"].astype(str) + "_" + core_dataset["date"].astype(str)

# keep only manual 5 coded documents
core_dataset.drop(core_dataset[core_dataset["manual"] != "5"].index, inplace=True)

# get all contained manifesto_ids
relevant_manifesto_ids = core_dataset["manifesto_id"].unique()

# get corpus texts of relevant manifestos via manifesto api
text_and_annotations_link = f"{root_url}api_texts_and_annotations.json"
response = requests.post(text_and_annotations_link,
                         data={"keys[]": list(relevant_manifesto_ids), "version": "2022-1",
                               "api_key": api_key})
response.raise_for_status()
manifestos = json.loads(response.content.decode('utf-8'))

# collect all manifesto texts as separate dataframes
collect_manifestos = []
for manifesto in manifestos["items"]:
    annotated_text = {index: entry for index, entry in enumerate(manifesto["items"])}
    annotated_text_df = pd.DataFrame.from_dict(annotated_text, orient="index")
    # skip manifestos without usable annotations: no cmp_code column at all (e.g. an empty
    # item list, which previously raised KeyError) or nothing but "NA" codes
    if "cmp_code" not in annotated_text_df.columns or annotated_text_df["cmp_code"].eq("NA").all():
        continue
    if "content" in annotated_text_df.columns:
        annotated_text_df.rename(columns={"content": "text"}, inplace=True)
    key_parts = manifesto["key"].split("_")
    annotated_text_df["manifesto_id"] = manifesto["key"]
    annotated_text_df["party"] = key_parts[0]
    annotated_text_df["date"] = key_parts[1]
    collect_manifestos.append(annotated_text_df)

if not collect_manifestos:
    raise RuntimeError("no annotated manifesto texts were returned by the manifesto api")

# concat whole list of manifesto texts to single dataframe (one row = one quasi-sentence)
corpus_df = pd.concat(collect_manifestos, ignore_index=True, sort=False)

# create country_code column in corpus dataframe (first two characters of the party id)
corpus_df["country_code"] = corpus_df["party"].astype(str).str[:2]

# recode manifesto labels to our adapted coding scheme
corpus_df = corpus_df.dropna(subset=["cmp_code"]).copy()
corpus_df["cmp_code_recoded"] = corpus_df["cmp_code"].replace(recoding_dict)

# remove rows that are not convertible to the new coding scheme: "H", "", "NA" and any code
# missing from recoding_dict are not topic ids, so this filter drops all of them
corpus_df = corpus_df[corpus_df["cmp_code_recoded"].isin(topic_label_dict.keys())]

# ----- create training and validation data -----

# load multilingual BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# tokenize quasi-sentences in corpus dataframe
training_sequences = tokenizer(list(corpus_df["text"]), truncation=True, padding=True, max_length=512)

# convert codes to indexes from index dict
training_labels = [cmp_to_index[cmp_code] for cmp_code in corpus_df["cmp_code_recoded"]]

# create custom pytorch dataset object from tokenized sequences and labels
train_dataset = MyDataset(training_sequences, training_labels)

# split dataset into 90% train and 10% validation parts. The validation size is the remainder,
# so the two lengths always sum to len(train_dataset) as random_split requires (rounding both
# parts independently can miss by one, e.g. 5 examples -> 4 + 0).
train_size = round(len(train_dataset) * 0.9)
valid_size = len(train_dataset) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])

# ----- training/loading model -----

if not train_model:
    # reuse a previously fine-tuned classifier; presumably "bert-multiling_512tokens" is a
    # local model directory — TODO confirm
    model = BertForSequenceClassification.from_pretrained("bert-multiling_512tokens")
    trainer = Trainer(model=model)
else:
    # start from pretrained multilingual BERT with one output per aggregated topic
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-multilingual-cased",
        num_labels=len(cmp_codes),
    )

    # fine-tuning hyper-parameters
    training_args = TrainingArguments(
        output_dir='./results',
        logging_dir='./logs',
        num_train_epochs=1,
        learning_rate=3e-5,
        warmup_steps=100,
        weight_decay=0.01,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        logging_steps=2000,
        evaluation_strategy="epoch",
        save_strategy="no",
    )

    # fine-tune on the manifesto data, evaluating on the held-out split once per epoch.
    # NOTE(review): nothing here writes the fine-tuned weights to disk (save_strategy="no",
    # no save_model call), so this run does not produce the directory loaded above.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )
    trainer.train()

# ----- classify press releases -----

# work on an independent copy so the column assignments below never act on a view
pressrelease_policy_df = pressrelease_policy_df.copy()

# tokenize press releases (missing texts become empty strings)
pressrelease_policy_df['text'] = pressrelease_policy_df['text'].fillna("")
pressrelease_document_sequences = tokenizer(list(pressrelease_policy_df["text"]),
                                            truncation=True,
                                            padding=True,
                                            max_length=512)

# create torch dataset object from unlabeled press release data
pressrelease_whole_document_dataset = MyDatasetRaw(pressrelease_document_sequences)

# predict press release data
predicted_results_pressrelease_documents = trainer.predict(pressrelease_whole_document_dataset)

# get probabilities for every topic for every text (softmax over the class axis of the logits);
# column j corresponds to class index j, i.e. topic id cmp_codes[j]
probabilities_pressreleases_documents = torch.softmax(
    torch.as_tensor(predicted_results_pressrelease_documents.predictions), dim=-1).numpy()
probability_columns = [f"Prob_{topic_label_dict[cmp_code]}_{cmp_code}" for cmp_code in cmp_codes]
prediction_df_pressreleases_documents = pd.DataFrame(probabilities_pressreleases_documents,
                                                     index=pressrelease_policy_df.index,
                                                     columns=probability_columns)

# set political authority to 0 probability so it is effectively excluded from the argmax below
prediction_df_pressreleases_documents['Prob_Political_Authority_7'] = 0

# issue = topic label of the most probable remaining topic. It is stored once, on the press
# release frame only, so the merged output has a single "issue" column.
column_to_label = dict(zip(probability_columns, (topic_label_dict[c] for c in cmp_codes)))
pressrelease_policy_df["issue"] = (prediction_df_pressreleases_documents
                                   .idxmax(axis=1)
                                   .map(column_to_label))

# merge prediction df (pd.concat replaces DataFrame.append, which pandas 2.0 removed)
final_prediction_df_pressreleases_documents = pd.concat(
    [pressrelease_policy_df, prediction_df_pressreleases_documents], axis=1)
final_pressrelease_df = pd.concat([pressrelease_non_policy_df, final_prediction_df_pressreleases_documents])
final_pressrelease_df = final_pressrelease_df.sort_index()

# save final prediction
final_pressrelease_df.to_csv("Classified_Pressreleases_20210401_20210926.csv", index=False, encoding="utf-8")